2024.7.10 optimを用いた勾配法による最適化【torch】
関数$ y(x)=x^2 -2x-3の最小値を初期値$ x_0 = 2.0から最急降下法を用いて探索する。
code:optim1.py
import torch as pt
import matplotlib.pyplot as plt
import torch.optim as optim
def func(x):
return x**2 -2*x - 3
# >>> パラメータ
x0 = 2.0 # 初期値をint型で用意
alpha = 0.1 # 学習率
hist_x = []
# <<<
x = pt.tensor(x0, dtype=pt.float, requires_grad=True)
optimizer = optim.SGD(x, lr=alpha) for i in range(50):
hist_x.append(x.detach().clone())
optimizer.zero_grad()
y = func(x)
y.backward() # x.grad を求める
optimizer.step() # x.grad を用いて x を更新
hist_x.append(x.detach().clone()) # 最後の1つをアペンド
hist_x = pt.stack(hist_x)
plt.plot(hist_x)
plt.show()
可視化してみる。
code:optim2.py
import torch as pt
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
def func(x):
return x**2 -2*x - 3
# >>> パラメータ
x0 = 2.0
alpha = 0.1
hist_x = []
hist_y = []
# <<<
x = pt.tensor(x0, dtype=pt.float, requires_grad=True)
y = func(x)
optimizer = optim.SGD(x, lr=alpha) for i in range(50):
hist_x.append(x.detach().numpy().copy())
optimizer.zero_grad()
y = func(x)
y.backward()
optimizer.step()
hist_y.append(y.detach().clone())
hist_x = np.stack(hist_x)
hist_y = np.stack(hist_y)
x = np.linspace(0.5, 2)
y = func(x)
plt.plot(x, y)
plt.plot(hist_x, hist_y, '*')
plt.grid()
plt.show()
https://scrapbox.io/files/668ddba3c18aa5001cde94ee.png
4次関数の例
$ y(x) = (2x+1)(x+2)(x-1)xの極小値を初期値$ x_0 = 0より探索。
code:opt3.py
import torch as pt
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
def func(x):
return (2*x + 1)*(x + 2)*(x - 1)*x
# >>> パラメータ
x0 = 0.0
alpha = 0.01
hist_x = []
hist_y = []
# <<<
x = pt.tensor(x0, dtype=pt.float, requires_grad=True)
y = func(x)
optimizer = optim.SGD(x, lr=alpha) for i in range(50):
hist_x.append(x.detach().numpy().copy())
optimizer.zero_grad()
y = func(x)
y.backward()
optimizer.step()
hist_y.append(y.detach().numpy().copy())
hist_x = np.stack(hist_x)
hist_y = np.stack(hist_y)
x = np.linspace(-2.1, 1.1) # デフォルトでは50点
y = func(x)
plt.plot(x, y)
plt.plot(hist_x, hist_y, '*')
plt.grid()
plt.show()
結果は初期値に依存することがよくわかる。
https://scrapbox.io/files/668dddb3859fb3001cb76f9b.png